{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import jaccard_score\n",
    "from PIL import Image\n",
    "\n",
    "# Per-class IoU of the reprojected label images against the color-coded ground truth.\n",
    "\n",
    "# Ground-truth RGB color to class label\n",
    "gt_color_to_label = {\n",
    "    (128,   0,   0) : 0,  # building\n",
    "    (128, 128, 128) : 1,  # sky\n",
    "    (128,  64, 128) : 2,  # road\n",
    "    (128, 128,   0) : 3,  # vegetation\n",
    "    (  0,   0, 192) : 4,  # sidewalk\n",
    "    ( 64,   0, 128) : 5,  # car\n",
    "    ( 64,  64,   0) : 6,  # pedestrian\n",
    "    (  0, 128, 192) : 7,  # cyclist\n",
    "    (192, 128, 128) : 8,  # signage\n",
    "    ( 64,  64, 128) : 9,  # fence\n",
    "    (192, 192, 128) : 10, # pole\n",
    "    (  0,   0,   0) : 255 # invalid\n",
    "}\n",
    "\n",
    "# NOTE(review): machine-specific absolute paths; adjust them before running elsewhere.\n",
    "img_gt_dir = '/home/ganlu/bkism_ws/src/BKISemanticMapping/data/data_kitti_15/kitti_15/'\n",
    "img_pred_dir = '/home/ganlu/bkism_ws/src/BKISemanticMapping/data/data_kitti_15/reproj_img/'\n",
    "evaluation_list_path = '/home/ganlu/bkism_ws/src/BKISemanticMapping/data/data_kitti_15/evaluatioList.txt'\n",
    "# np.loadtxt yields a 0-d array for a one-line file; atleast_1d keeps it iterable.\n",
    "evaluation_list = np.atleast_1d(np.loadtxt(evaluation_list_path))\n",
    "\n",
    "\n",
    "def rgb_to_label(img_color):\n",
    "    \"\"\"Convert a (rows, cols, 3) RGB image into a (rows, cols) float64 label image.\n",
    "\n",
    "    Every color is looked up in gt_color_to_label. A color that is missing\n",
    "    from the table raises KeyError carrying the offending (r, g, b) tuple.\n",
    "    \"\"\"\n",
    "    channels = img_color.astype(np.int32)\n",
    "    # Pack each RGB triple into one integer so the distinct colors of the\n",
    "    # image can be collected in a single vectorized pass.\n",
    "    codes = channels[..., 0] * 65536 + channels[..., 1] * 256 + channels[..., 2]\n",
    "    unique_codes, inverse = np.unique(codes, return_inverse=True)\n",
    "    # Only the handful of distinct colors goes through the Python dictionary.\n",
    "    lut = np.array([gt_color_to_label[(c // 65536, (c // 256) % 256, c % 256)]\n",
    "                    for c in unique_codes.tolist()], dtype=np.float64)\n",
    "    return lut[inverse].reshape(codes.shape)\n",
    "\n",
    "\n",
    "img_pred_all = []\n",
    "img_gt_all = []\n",
    "for img_id in evaluation_list:\n",
    "    # Image ids are stored as numbers; file names use six zero-padded digits.\n",
    "    img_id = '%06i' % img_id\n",
    "    print(img_id)\n",
    "\n",
    "    # Read images; the context managers close each file once its pixels are copied.\n",
    "    with Image.open(img_gt_dir + img_id + '.png') as img_file:\n",
    "        # Force three channels so palette or RGBA PNGs still match the color table.\n",
    "        img_gt_color = np.array(img_file.convert('RGB'))\n",
    "    with Image.open(img_pred_dir + img_id + '_bw.png') as img_file:\n",
    "        img_pred = np.array(img_file)\n",
    "\n",
    "    # Convert rgb to label\n",
    "    img_gt = rgb_to_label(img_gt_color)\n",
    "    if img_pred.shape != img_gt.shape:\n",
    "        raise ValueError('prediction %s and ground truth %s differ in shape for image %s'\n",
    "                         % (img_pred.shape, img_gt.shape, img_id))\n",
    "\n",
    "    img_gt_all.append(img_gt)\n",
    "    img_pred_all.append(img_pred)\n",
    "\n",
    "if not img_gt_all:\n",
    "    raise ValueError('evaluation list is empty; nothing to evaluate')\n",
    "\n",
    "# Flatten image by image, so images of different sizes can be pooled as well.\n",
    "img_gt_all = np.concatenate([img.ravel() for img in img_gt_all])\n",
    "img_pred_all = np.concatenate([img.ravel() for img in img_pred_all])\n",
    "\n",
    "# Ignore sky and invalid labels in gt; one mask keeps gt and pred aligned.\n",
    "valid = np.logical_and(img_gt_all != 1, img_gt_all != 255)\n",
    "img_pred_all = img_pred_all[valid]\n",
    "img_gt_all = img_gt_all[valid]\n",
    "\n",
    "# Ignore invalid labels in pred (optional, disabled)\n",
    "#img_gt_all = img_gt_all[img_pred_all != 255]\n",
    "#img_pred_all = img_pred_all[img_pred_all != 255]\n",
    "\n",
    "\n",
    "# IoU for each class\n",
    "# The first line lists the labels present; per the scikit-learn docs the scores\n",
    "# on the second line follow the same sorted label order.\n",
    "print( np.unique(np.concatenate((img_gt_all, img_pred_all), axis=0)) )\n",
    "print( jaccard_score(img_gt_all, img_pred_all, average=None) )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
